import numpy as np

from utils.utils import sigmoid, xywh2xyxy, iou


def decode_outputs(outputs, input_shape):
    """
    Decode raw YOLOX head outputs into normalized boxes and probabilities.

    Adapted from: https://github.com/bubbliiiing/yolox-pytorch

    :param outputs: sequence of per-level head outputs. Each is expected to be a
        4-D array laid out as (batch, channels, h, w). Channels 0-1 are treated as
        cell-relative x/y offsets, channels 2-3 as log box sizes, and channels 4+
        as logits (objectness, then presumably class scores -- confirm against
        the model head).
    :param input_shape: network input size as (height, width); index 0 normalizes
        the y/h channels and index 1 the x/w channels.
    :return: array of shape (batch, total_cells, channels) in which channels 0-3
        are (x, y, w, h) divided by the input width/height and channels 4+ are
        sigmoid probabilities.
    """
    # Record every level's (h, w) before flattening so the grids can be rebuilt.
    hw = [x.shape[-2:] for x in outputs]
    # (b, c, h, w) -> (b, c, h*w) per level, concatenate levels, then -> (b, N, c).
    outputs = np.concatenate(
        [x.reshape(x.shape[:-3] + (x.shape[1], -1)) for x in outputs], axis=2
    ).transpose([0, 2, 1])

    # Objectness / class logits -> probabilities.
    outputs[:, :, 4:] = sigmoid(outputs[:, :, 4:])

    grids = []
    strides = []
    for h, w in hw:
        # 'ij' indexing yields (h, w) arrays with grid_x[i, j] == j and
        # grid_y[i, j] == i; their row-major flattening matches the h*w
        # flattening above, and non-square feature maps are handled correctly.
        grid_y, grid_x = np.meshgrid(np.arange(h), np.arange(w), indexing="ij")
        grid = np.stack([grid_x, grid_y], axis=2).reshape((1, -1, 2))
        grids.append(grid)
        # One stride value per cell, derived from the input height; the same
        # stride is applied on both axes.
        strides.append(np.full((1, grid.shape[1], 1), input_shape[0] / h))

    grids = np.concatenate(grids, axis=1)
    strides = np.concatenate(strides, axis=1)

    # Cell offset + grid position -> input-pixel position; exp(log size) -> pixels.
    outputs[..., :2] = (outputs[..., :2] + grids) * strides
    outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * strides

    # Normalize x/w by the input width and y/h by the input height.
    outputs[..., [0, 2]] = outputs[..., [0, 2]] / input_shape[1]
    outputs[..., [1, 3]] = outputs[..., [1, 3]] / input_shape[0]
    return outputs


def non_max_suppression(prediction, num_classes, conf_thre=0.5, iou_thre=0.5):
    """
    Filter decoded predictions by confidence and apply per-class NMS.

    Only the first image of the batch (``prediction[0]``) is processed.

    :param prediction: decoded predictions, presumably shaped
        (batch, N, 5 + num_classes) as produced by ``decode_outputs``. Columns 0-3
        are passed through ``xywh2xyxy``, column 4 is objectness and columns
        5 .. 5 + num_classes are class scores.
    :param num_classes: number of class-score columns.
    :param conf_thre: minimum ``objectness * best class score`` for a box to be kept.
    :param iou_thre: a lower-ranked box of the same class is suppressed when its
        IoU with an already kept box is >= this value.
    :return: list of 1-D rows ``[box(4), obj_conf, class_conf, class_id]`` (box
        columns as returned by ``xywh2xyxy``), ordered by descending
        ``obj_conf * class_conf``.
    """
    prediction = xywh2xyxy(prediction)
    prediction = prediction[0]
    class_scores = prediction[:, 5:5 + num_classes]
    class_conf = np.max(class_scores, axis=1, keepdims=True)
    class_pred = np.argmax(class_scores, axis=1, keepdims=True)

    # Keep the mask 1-D: squeezing it would turn a single-row prediction into a
    # 0-d boolean, which indexes by adding an axis instead of selecting rows.
    conf_mask = prediction[:, 4] * class_conf[:, 0] >= conf_thre

    detections = np.concatenate(
        [prediction[:, :5], class_conf, class_pred.astype(np.float32)], axis=1
    )
    detections = detections[conf_mask]
    # Rank by the same combined score used for filtering (objectness * class
    # confidence), so NMS keeps the box with the best overall score.
    detections = sorted(detections, key=lambda det: det[4] * det[5], reverse=True)

    bbox_after_nms = _nms(detections, iou_thre)

    return bbox_after_nms


def _nms(detections, iou_thre):
    bbox_afer_nms = []

    # 非极大值抑制
    while detections:
        chosen_box = detections.pop(0)
        detections = [
            box
            for box in detections
            if box[-1] != chosen_box[-1]
               or iou(
                chosen_box,
                box
            ) < iou_thre
        ]
        bbox_afer_nms.append(chosen_box)
    return bbox_afer_nms


def get_real_boxes(boxes, image_shape):
    """
    Map normalized boxes back to pixel coordinates of the original image.

    Coordinates are clipped to [0, 1], scaled by the image's longest side,
    shifted by ``(longest - shortest) // 2`` along the shorter dimension, and
    finally clipped to the image extent. This assumes the network input was the
    original image centre-padded to a square -- confirm against preprocessing.

    Index convention (fixed by how the offsets are applied): ``image_shape[0]``
    is the extent along x (box entries 0 and 2) and ``image_shape[1]`` the
    extent along y (entries 1 and 3), i.e. presumably (width, height).

    :param boxes: iterable of 1-D arrays whose first four entries are normalized
        (x1, y1, x2, y2); each array is modified in place.
    :param image_shape: original image size, following the convention above.
    :return: ``boxes`` (the same object), now holding pixel coordinates.
    """
    side = np.max(image_shape)
    short_dim = np.argmin(image_shape)
    # Padding added before the image along its shorter dimension.
    offset = np.zeros(2)
    offset[short_dim] = (side - image_shape[short_dim]) // 2
    extent_x, extent_y = image_shape[0], image_shape[1]

    for box in boxes:
        box[:4] = np.clip(box[:4], 0, 1) * side
        # Remove the padding offset and keep the box inside the real image,
        # so boxes reaching into the padded area do not exceed the far edge.
        box[[0, 2]] = np.clip(box[[0, 2]] - offset[0], 0, extent_x)
        box[[1, 3]] = np.clip(box[[1, 3]] - offset[1], 0, extent_y)
    return boxes
